import pandas as pd
import numpy as np
import time
import cv2
import torch
from torch.autograd import Variable
import lib.utils.utils as utils
import lib.models.crnn as crnn
import lib.config.alphabets as alphabets
import yaml
from easydict import EasyDict as edict
import argparse
import os
from torchvision import transforms
from PIL import Image


def parse_arg():
    """Parse command-line arguments and load the experiment configuration.

    Returns:
        (config, args): the YAML configuration as an EasyDict, augmented
        with the recognition alphabet and the derived number of classes,
        and the parsed argparse namespace.
    """
    parser = argparse.ArgumentParser(description="demo")

    parser.add_argument('--cfg', help='experiment configuration filename', type=str,
                        default='lib/config/OWN_config.yaml')
    parser.add_argument('--image_path', type=str, default='../data/test/', help='the path to your image')
    parser.add_argument('--checkpoint', type=str,
                        default='output/OWN/crnn/2020-06-09-21-02/checkpoints/checkpoint_64_acc_0.7318.pth',
                        help='the path to your checkpoints')

    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        # yaml.load() without an explicit Loader is deprecated (PyYAML >= 5.1)
        # and can execute arbitrary Python; safe_load fully covers config files.
        config = edict(yaml.safe_load(f))

    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)

    return config, args


def recognition(config, img, model, converter, device):
    """Run CRNN inference on one preprocessed image and decode the text.

    Args:
        config: experiment configuration (unused by the active code path;
            kept for interface compatibility with callers).
        img: preprocessed image tensor, assumed (C, H, W) — batch dim added here.
        model: CRNN network whose forward returns per-timestep class scores
            of shape (T, B, num_classes).
        converter: label converter exposing ``decode(preds, preds_size, raw)``.
        device: torch device to run inference on.

    Returns:
        str: the decoded (CTC-collapsed) prediction for the image.
    """
    img = img.to(device)
    img = img.view(1, *img.size())  # prepend batch dimension: (1, C, H, W)
    model.eval()
    with torch.no_grad():  # pure inference — skip autograd bookkeeping
        preds = model(img)

    # Greedy decoding: argmax class index per timestep, flattened to (T,)
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)

    # torch.autograd.Variable is deprecated; plain tensors suffice since 0.4
    preds_size = torch.IntTensor([preds.size(0)])
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
    return sim_pred


if __name__ == '__main__':

    config, args = parse_arg()
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # Preprocessing pipeline: crop to the text region, resize to the model's
    # input height/width, convert to tensor, and apply ImageNet normalization.
    ptransform = transforms.Compose([
        transforms.CenterCrop((70, 200)),
        transforms.Resize([32, 200]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )
    model = crnn.get_crnn(config).to(device)
    print('loading pretrained model from {0}'.format(args.checkpoint))
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    model.load_state_dict(torch.load(args.checkpoint, map_location=device)['state_dict'])

    # The converter depends only on the alphabet — build it once, not per image
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)

    result = []
    num = []
    for img_path in os.listdir(args.image_path):
        img = Image.open(os.path.join(args.image_path, img_path)).convert('RGB')
        img = ptransform(img)

        pred = recognition(config, img, model, converter, device)
        result.append(pred)
        num.append(os.path.splitext(img_path)[0])  # filename without extension

    # BUG FIX: these four lines were previously outside the __main__ guard,
    # so merely importing this module raised NameError on `num`/`result`.
    test_csv = pd.DataFrame()
    test_csv[0] = num
    test_csv[1] = result
    test_csv.to_csv('submission.csv', index=None, header=None)
